import torch
import numpy as np
from utils_others import compute_feature_by_dataloader
from utils_data_ic import ImageDataset

class Buffer:
    """
    The memory buffer of rehearsal method.

    Samples are kept per task: ``self.X[t]`` and ``self.y[t]`` are parallel
    lists holding the inputs and labels stored for task ``t``. A *global
    index* addresses the concatenation of all per-task lists in task order
    (see ``get_location``).

    Two sampling algorithms are recognised by the methods below:

    * ``'herding'``   - handled in ``finish_end_task``.
    * ``'reservior'`` - handled in ``init_begin_task`` /
      ``update_buffer_batch``. NOTE: the key is spelled ``'reservior'``
      (sic); it is compared verbatim, so callers must use this spelling.
    """
    def __init__(self, buffer_size:int, class_each_task:list, buffer_batch_size:int, 
                 sampling_alg:str, is_fix_budget_each_class:bool, is_mix_er:bool):
        """
        Args:
            buffer_size: total sample budget of the buffer.
            class_each_task: one entry per task; its length defines the
                number of tasks.
                NOTE(review): ``init_begin_task`` applies ``np.sum`` to each
                entry, which matches the "number of classes" intent only if
                entries are integer class counts - confirm against callers.
            buffer_batch_size: number of samples drawn by ``get_buffer_batch``.
            sampling_alg: ``'herding'`` or ``'reservior'`` (see class docstring).
            is_fix_budget_each_class: if True, the herding per-class budget is
                computed over the classes of *all* tasks instead of only the
                tasks seen so far.
            is_mix_er: if True, ``get_buffer_batch`` draws only half a batch
                when sampling from old tasks.
        """
        self.buffer_size = buffer_size
        self.class_each_task = class_each_task
        self.n_task = len(class_each_task)
        self.buffer_batch_size = buffer_batch_size
        self.sampling_alg = sampling_alg
        self.is_fix_budget_each_class = is_fix_budget_each_class
        self.is_mix_er = is_mix_er
        self.X = [[] for _ in range(self.n_task)] # for each task
        self.y = [[] for _ in range(self.n_task)] # for each task

    def __len__(self) -> int:
        """Return the total number of stored samples over all tasks."""
        # Builtin sum keeps the result a plain int; np.sum([]) would be 0.0.
        return sum(len(task_X) for task_X in self.X)
    
    def get_old_size(self, task_id) -> int:
        """Return the number of stored samples of tasks ``0 .. task_id-1``."""
        return sum(len(task_X) for task_X in self.X[:task_id])
    
    def get_location(self, idx) -> tuple:
        """
        Map a global sample index to ``(task_id, index_inside_task)``.

        Raises:
            ValueError: if ``idx`` is negative or not smaller than ``len(self)``.
        """
        if idx >= 0:
            task_start = 0
            for t_id, task_X in enumerate(self.X):
                task_end = task_start + len(task_X)
                if idx < task_end:
                    return t_id, idx - task_start
                task_start = task_end
        raise ValueError('Invalid idx %d'%(idx))

    def _gather(self, select_idx) -> list:
        """Return ``[[idx, X, y], ...]`` for every global index in ``select_idx``."""
        batch_data = []
        for _idx in select_idx:
            t_id, s_id = self.get_location(_idx)
            batch_data.append([_idx, self.X[t_id][s_id], self.y[t_id][s_id]])
        return batch_data

    @staticmethod
    def _label_to_python(label, n_dim):
        """
        Convert one label into the form stored in ``self.y``.

        A 0-d label becomes a Python scalar via ``.item()``; a 1-d label
        becomes a list of its elements. Any other rank is unsupported.
        """
        if n_dim == 0:
            return label.item()
        if n_dim == 1:
            return list(label.cpu().numpy())
        raise NotImplementedError()

    @staticmethod
    def _fetch_raw_sample(dataset, sample_idx):
        """
        Fetch the input of ``sample_idx`` from ``dataset`` without transforms.

        If ``dataset`` exposes a ``.dataset`` attribute (presumably a
        subset-style wrapper - confirm), the wrapped dataset is queried
        instead. When the wrapped dataset's ``data`` holds strings, the raw
        string is returned as is; otherwise element ``[1]`` of
        ``__getitem__(idx, is_transform=False)`` is used.
        """
        if hasattr(dataset, 'dataset'):
            inner_dataset = dataset.dataset
            if isinstance(inner_dataset.data[0], str):
                return inner_dataset.data[sample_idx]
            return inner_dataset.__getitem__(sample_idx, is_transform=False)[1]
        return dataset.__getitem__(sample_idx, is_transform=False)[1]
    
    def finish_end_task(self, task_id, train_loader, model, cur_class_each_task, task_name):
        """
        Update the buffer after training on ``task_id`` (herding only).

        For ``sampling_alg == 'herding'`` this (1) shrinks every old task to
        the new per-class budget, keeping the earliest stored samples of each
        class, and (2) selects samples of the current task per class by
        herding: greedily pick the sample that keeps the running mean of the
        selected features closest to the class feature mean. For any other
        sampling algorithm this method does nothing.

        Args:
            task_id: index of the task that just finished.
            train_loader: loader over the current task's training data.
            model: passed to ``compute_feature_by_dataloader`` to extract features.
            cur_class_each_task: per-task list of class ids.
            task_name: must be ``'IC'`` when herding is used.
        """
        if self.sampling_alg != 'herding':
            return

        assert task_name == 'IC','Not Implemented for other tasks when using herding algorithm'
        # Budget is shared by all classes of all tasks (fixed budget) or only
        # by the classes seen up to and including the current task.
        n_budget_task = self.n_task if self.is_fix_budget_each_class else task_id+1
        n_class_total = sum(len(cur_class_each_task[t_id]) for t_id in range(n_budget_task))
        n_sample_each_class = self.buffer_size//n_class_total

        # Shrink old tasks: keep the first n_sample_each_class samples per class.
        for t_id in range(task_id):
            kept_X = []
            kept_y = []
            sample_cnt_dict = {class_id:0 for class_id in cur_class_each_task[t_id]}
            for old_X, old_y in zip(self.X[t_id], self.y[t_id]):
                assert isinstance(old_y, int), 'Not Implemented for Sequential Labeling Tasks (NER)' 
                if sample_cnt_dict[old_y] < n_sample_each_class:
                    kept_X.append(old_X)
                    kept_y.append(old_y)
                    sample_cnt_dict[old_y] += 1
            self.X[t_id] = kept_X
            self.y[t_id] = kept_y

        idx_list, features_matrix, y_list = compute_feature_by_dataloader(train_loader, 
                                                                            model,
                                                                            is_normalize=True,
                                                                            return_idx=True)
        
        for class_idx in cur_class_each_task[task_id]:
            class_mask = np.equal(y_list, class_idx)
            class_idx_list = idx_list[class_mask]
            class_y_list = y_list[class_mask]
            class_feat = features_matrix[class_mask]
            n_class_sample = class_feat.shape[0]
            if n_class_sample == 0:
                # No training sample of this class: nothing to select from.
                continue
            class_feat_mean = np.mean(class_feat,axis=0).reshape(1,-1)
            running_feat_sum = np.zeros_like(class_feat_mean)
            is_selected = np.zeros(n_class_sample, dtype=bool)
            # Never select more samples than the class actually has; bounding
            # by the whole dataset size would store duplicates.
            n_to_select = min(n_sample_each_class, n_class_sample)
            for cnt_sample in range(n_to_select):
                # select sample: distance between the class mean and the mean
                # of (already selected + candidate), for every candidate
                dist_all = np.linalg.norm((class_feat_mean-(running_feat_sum+class_feat)/(cnt_sample+1)),ord=2,axis=-1)
                dist_all[is_selected] = np.inf # a sample is selected at most once
                min_idx = np.argmin(dist_all)
                X_select = self._fetch_raw_sample(train_loader.dataset, class_idx_list[min_idx])
                y_select = class_y_list[min_idx]

                # Convert both values before storing so X and y stay aligned
                # even if the label rank is unsupported.
                if isinstance(X_select,torch.Tensor):
                    X_store = list(X_select.cpu().numpy())
                else:
                    X_store = X_select
                y_store = self._label_to_python(y_select, len(y_select.shape))

                # add new sample
                self.X[task_id].append(X_store)
                self.y[task_id].append(y_store)

                # update variable
                running_feat_sum += class_feat[min_idx:min_idx+1]
                is_selected[min_idx] = True

        assert len(self) <= self.buffer_size
         
    def init_begin_task(self, task_id):
        '''
            Initialize the buffer for the ``task_id``-th task (reservoir only).

            For ``sampling_alg == 'reservior'`` the budget is re-allocated
            proportionally to ``np.sum(self.class_each_task[t])``: old tasks
            ( 0 ~ task_id-1 ) are trimmed to their new share (the most
            recently stored samples are kept), a share is pre-allocated for
            the current task, and the seen-sample counter is reset.
            For any other sampling algorithm this method does nothing.
        '''
        if self.sampling_alg != 'reservior':
            return

        self.buffer_size_each_task = [0]*self.n_task

        # Allocate according to number of classes each task.
        # (An alternative policy would split the budget equally across tasks.)
        # Delete some old tasks' samples ( 0 ~ task_id-1 ), the sum should be <= buffer size
        n_class_old = np.sum(self.class_each_task[:task_id])
        for t_id in range(task_id):
            adjust_task_size = int(self.buffer_size*np.sum(self.class_each_task[t_id])//n_class_old)
            self.buffer_size_each_task[t_id] = adjust_task_size
            # Slice by the excess count rather than lst[-size:], because
            # lst[-0:] would keep the whole list when the share is zero.
            n_excess = len(self.X[t_id]) - adjust_task_size
            if n_excess > 0:
                self.X[t_id] = self.X[t_id][n_excess:]
                self.y[t_id] = self.y[t_id][n_excess:]
        # Pre-allocate the space for the current tasks (not included in the buffer size)
        self.buffer_size_each_task[task_id] = int(
            self.buffer_size*np.sum(self.class_each_task[task_id])//np.sum(self.class_each_task[:task_id+1])
        )

        assert np.sum(self.buffer_size_each_task[:task_id]) <= self.buffer_size

        self.num_seen_sample = 0 # for reservoir sampling in the current task

    def get_buffer_batch(self, task_id=None, select_idx=None):
        '''
            Randomly sample a batch of data
            if task_id is not None:  from the old tasks ( 0 ~ task_id-1 );
            else : from all (seen) tasks

            Indices are drawn with ``np.random.choice`` using its defaults.
            The batch holds ``buffer_batch_size`` samples, or half of that
            when ``is_mix_er`` is set and ``task_id`` > 0.
            If ``select_idx`` is given, exactly those global indices are
            returned and nothing is drawn.

            Returns a list of ``[idx, X, y]`` entries.
        '''
        if select_idx is None:
            is_half_batch = (task_id is not None) and (task_id>0) and (self.is_mix_er)
            batch_size = self.buffer_batch_size//2 if is_half_batch else self.buffer_batch_size
            n_candidate = len(self) if task_id is None else self.get_old_size(task_id)
            select_idx = np.random.choice(n_candidate, batch_size)
        return self._gather(select_idx) # [(idx, X, y),...]
    
    def get_buffer_all(self, task_id=None):
        '''
            Return every stored sample, in global-index order.
            if task_id is not None:  only the old tasks ( 0 ~ task_id-1 );
            else : all tasks

            Returns a list of ``[idx, X, y]`` entries.
        '''
        n_sample = len(self) if task_id is None else self.get_old_size(task_id)
        return self._gather(range(n_sample)) # [(idx, X, y),...]
    
    def update_buffer_batch(self, batch_X, batch_y, task_id):
        '''
            Feed a batch of the current task into its reservoir.

            While the task's share (``buffer_size_each_task[task_id]``) is not
            full, samples are appended. Afterwards the n-th seen sample
            replaces a uniformly chosen stored sample with probability
            share / n (reservoir sampling).

            ``init_begin_task`` must have been called for this task first; it
            creates ``buffer_size_each_task`` and ``num_seen_sample``.

            batch_y must expose ``.shape``; 1-d batches store scalars and
            2-d batches store lists, other ranks raise NotImplementedError.
        '''
        num_sample = batch_y.shape[0]
        is_tensor_X = isinstance(batch_X,torch.Tensor)
        label_n_dim = len(batch_y.shape)-1 # rank of one label inside the batch
        task_X = self.X[task_id]
        task_y = self.y[task_id]
        for i in range(num_sample):
            self.num_seen_sample += 1
            if len(task_X) < self.buffer_size_each_task[task_id]:
                replace_idx = None # reservoir not full yet: append
            else:
                # num_seen_sample already counts the current sample and the
                # upper bound of randint is exclusive, so this draws from
                # [0, n) and the sample is kept with probability share / n.
                replace_idx = np.random.randint(0, self.num_seen_sample)
                if replace_idx >= len(task_X):
                    continue # current sample is discarded

            # Convert both values before touching the lists so that X and y
            # stay aligned even if the label rank is unsupported.
            X_store = list(batch_X[i].cpu().numpy()) if is_tensor_X else batch_X[i]
            y_store = self._label_to_python(batch_y[i], label_n_dim)

            if replace_idx is None:
                # add new sample
                task_X.append(X_store)
                task_y.append(y_store)
            else:
                # replace sample
                task_X[replace_idx] = X_store
                task_y[replace_idx] = y_store
